/*
 * Decompiled with CFR 0.152.
 */
package software.bernie.geckolib.loading.math;

import com.google.gson.JsonElement;
import com.google.gson.JsonPrimitive;
import com.mojang.datafixers.util.Either;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import java.util.function.ToDoubleFunction;
import java.util.regex.Pattern;
import net.minecraft.class_156;
import org.apache.logging.log4j.Level;
import org.jspecify.annotations.Nullable;
import software.bernie.geckolib.GeckoLibConstants;
import software.bernie.geckolib.animation.state.ControllerState;
import software.bernie.geckolib.loading.math.MathValue;
import software.bernie.geckolib.loading.math.MolangQueries;
import software.bernie.geckolib.loading.math.Operator;
import software.bernie.geckolib.loading.math.function.MathFunction;
import software.bernie.geckolib.loading.math.function.generic.ACosFunction;
import software.bernie.geckolib.loading.math.function.generic.ASinFunction;
import software.bernie.geckolib.loading.math.function.generic.ATan2Function;
import software.bernie.geckolib.loading.math.function.generic.ATanFunction;
import software.bernie.geckolib.loading.math.function.generic.AbsFunction;
import software.bernie.geckolib.loading.math.function.generic.CosFunction;
import software.bernie.geckolib.loading.math.function.generic.ExpFunction;
import software.bernie.geckolib.loading.math.function.generic.LogFunction;
import software.bernie.geckolib.loading.math.function.generic.ModFunction;
import software.bernie.geckolib.loading.math.function.generic.PowFunction;
import software.bernie.geckolib.loading.math.function.generic.SinFunction;
import software.bernie.geckolib.loading.math.function.generic.SqrtFunction;
import software.bernie.geckolib.loading.math.function.limit.ClampFunction;
import software.bernie.geckolib.loading.math.function.limit.MaxFunction;
import software.bernie.geckolib.loading.math.function.limit.MinFunction;
import software.bernie.geckolib.loading.math.function.misc.PiFunction;
import software.bernie.geckolib.loading.math.function.misc.ToDegFunction;
import software.bernie.geckolib.loading.math.function.misc.ToRadFunction;
import software.bernie.geckolib.loading.math.function.random.DieRollFunction;
import software.bernie.geckolib.loading.math.function.random.DieRollIntegerFunction;
import software.bernie.geckolib.loading.math.function.random.RandomFunction;
import software.bernie.geckolib.loading.math.function.random.RandomIntegerFunction;
import software.bernie.geckolib.loading.math.function.round.CeilFunction;
import software.bernie.geckolib.loading.math.function.round.FloorFunction;
import software.bernie.geckolib.loading.math.function.round.HermiteBlendFunction;
import software.bernie.geckolib.loading.math.function.round.LerpFunction;
import software.bernie.geckolib.loading.math.function.round.LerpRotFunction;
import software.bernie.geckolib.loading.math.function.round.RoundFunction;
import software.bernie.geckolib.loading.math.function.round.TruncateFunction;
import software.bernie.geckolib.loading.math.value.BooleanNegate;
import software.bernie.geckolib.loading.math.value.Calculation;
import software.bernie.geckolib.loading.math.value.CompoundValue;
import software.bernie.geckolib.loading.math.value.Constant;
import software.bernie.geckolib.loading.math.value.Group;
import software.bernie.geckolib.loading.math.value.Negative;
import software.bernie.geckolib.loading.math.value.Ternary;
import software.bernie.geckolib.loading.math.value.Variable;
import software.bernie.geckolib.loading.math.value.VariableAssignment;
import software.bernie.geckolib.object.CompoundException;

/**
 * Compiles Molang-style math expressions into {@link MathValue} trees.
 * <p>
 * Compilation runs in three stages:
 * <ol>
 *     <li>{@link #decomposeExpression(String)} validates the characters and parenthesis balance, then lowercases the
 *     expression and removes all whitespace</li>
 *     <li>{@link #compileSymbols(char[])} splits those characters into string symbols (values and operators) and
 *     pre-compiled parenthesised argument groups</li>
 *     <li>{@link #parseSymbols(List)} turns the symbols into the final {@link MathValue}</li>
 * </ol>
 * Functions are resolved by name from a registry pre-populated with the built-in {@code math.*} functions; further
 * functions can be added with {@link #registerFunction(String, MathFunction.Factory)}.
 */
public class MathParser {
    /** Characters permitted in an expression: word characters, whitespace, operator symbols, separators and parentheses */
    private static final Pattern EXPRESSION_FORMAT = Pattern.compile("^[\\w\\s_+-/*%^&|<>=!?:.,()]+$");
    private static final Pattern WHITESPACE = Pattern.compile("\\s");
    /** A plain decimal number with an optional leading minus sign and an optional fractional part */
    private static final Pattern NUMERIC = Pattern.compile("^-?\\d+(\\.\\d+)?$");
    /** Mirrors the validation regex suggested in the {@link Double#valueOf(String)} documentation */
    private static final Pattern VALID_DOUBLE = Pattern.compile("[\\x00-\\x20]*[+-]?(NaN|Infinity|((((\\d+)(\\.)?((\\d+)?)([eE][+-]?(\\d+))?)|(\\.(\\d+)([eE][+-]?(\\d+))?)|(((0[xX](\\p{XDigit}+)(\\.)?)|(0[xX](\\p{XDigit}+)?(\\.)(\\p{XDigit}+)))[pP][+-]?(\\d+)))[fFdD]?))[\\x00-\\x20]*");
    private static final String MOLANG_RETURN = "return ";
    private static final String STATEMENT_DELIMITER = ";";
    /**
     * Registered function factories keyed by function name.
     * <p>
     * A {@link ConcurrentHashMap} is used because {@link #registerFunction} may add entries after class initialisation.
     * The initial capacity is chosen to comfortably hold the built-in functions registered below.
     */
    private static final Map<String, MathFunction.Factory<?>> FUNCTION_FACTORIES = class_156.method_654(new ConcurrentHashMap<>(32), map -> {
        map.put("math.abs", AbsFunction::new);
        map.put("math.acos", ACosFunction::new);
        map.put("math.asin", ASinFunction::new);
        map.put("math.atan", ATanFunction::new);
        map.put("math.atan2", ATan2Function::new);
        map.put("math.ceil", CeilFunction::new);
        map.put("math.clamp", ClampFunction::new);
        map.put("math.cos", CosFunction::new);
        map.put("math.die_roll", DieRollFunction::new);
        map.put("math.die_roll_integer", DieRollIntegerFunction::new);
        map.put("math.exp", ExpFunction::new);
        map.put("math.floor", FloorFunction::new);
        map.put("math.hermite_blend", HermiteBlendFunction::new);
        map.put("math.lerp", LerpFunction::new);
        map.put("math.lerprotate", LerpRotFunction::new);
        map.put("math.ln", LogFunction::new);
        map.put("math.max", MaxFunction::new);
        map.put("math.min", MinFunction::new);
        map.put("math.mod", ModFunction::new);
        map.put("math.pi", PiFunction::new);
        map.put("math.pow", PowFunction::new);
        map.put("math.random", RandomFunction::new);
        map.put("math.random_integer", RandomIntegerFunction::new);
        map.put("math.round", RoundFunction::new);
        map.put("math.sin", SinFunction::new);
        map.put("math.sqrt", SqrtFunction::new);
        map.put("math.to_deg", ToDegFunction::new);
        map.put("math.to_rad", ToRadFunction::new);
        map.put("math.trunc", TruncateFunction::new);
    });

    /**
     * @param name The function name, e.g. {@code math.sin}
     * @return Whether a {@link MathFunction} factory is registered under the given name
     */
    public static boolean isFunctionRegistered(String name) {
        return FUNCTION_FACTORIES.containsKey(name);
    }

    /**
     * Register a {@link MathFunction} factory under the given name.
     * <p>
     * An existing registration with the same name is replaced, and a warning is logged when that happens.
     *
     * @param name The name the function is referenced by in expressions
     * @param factory The factory that builds the function from its argument values
     */
    public static void registerFunction(String name, MathFunction.Factory<?> factory) {
        if (FUNCTION_FACTORIES.put(name, factory) != null) {
            GeckoLibConstants.LOGGER.log(Level.WARN, "Duplicate registration of MathFunction: '{}'. Ignore if intentional override", (Object)name);
        }

        GeckoLibConstants.LOGGER.log(Level.DEBUG, "Registered MathFunction '{}'", (Object)name);
    }

    /**
     * Build an instance of the function registered under the given name.
     *
     * @param name The registered function name
     * @param values The argument values to pass to the function's factory
     * @return The built function, or {@link Optional#empty()} if no function is registered under {@code name}
     */
    @SuppressWarnings("unchecked")
    public static <T extends MathFunction> Optional<T> buildFunction(String name, MathValue... values) {
        // Single lookup instead of containsKey + get
        MathFunction.Factory<?> factory = FUNCTION_FACTORIES.get(name);

        if (factory == null) {
            return Optional.empty();
        }

        return Optional.of((T)factory.create(values));
    }

    /**
     * Register a variable, delegating to {@link MolangQueries#registerVariable(Variable)}.
     */
    public static void registerVariable(Variable variable) {
        MolangQueries.registerVariable(variable);
    }

    /**
     * @return The variable for the given name, as resolved by {@link MolangQueries#getVariableFor(String)}
     */
    public static Variable getVariableFor(String name) {
        return MolangQueries.getVariableFor(name);
    }

    /**
     * Set the value function of the variable resolved for the given name.
     *
     * @param name The variable name
     * @param value The function computing the variable's value from the current {@link ControllerState}
     */
    public static void setVariable(String name, ToDoubleFunction<ControllerState> value) {
        getVariableFor(name).set(value);
    }

    /**
     * Compile a JSON value from an animation file into a {@link MathValue}.
     * <p>
     * Numbers, and strings that are valid doubles, become {@link Constant}s. Any other string is compiled as a
     * Molang expression.
     *
     * @param element The JSON element to compile
     * @return The compiled value
     * @throws CompoundException If the element is not a non-boolean JSON primitive, or if the expression fails to compile
     */
    public static MathValue parseJson(JsonElement element) {
        if (!(element instanceof JsonPrimitive primitive) || primitive.isBoolean()) {
            // Guard against a null element so the intended error is reported instead of an NPE
            throw new CompoundException("Bad formatting on Molang expression, expected single value, received: " + (element == null ? "null" : element.getClass().getSimpleName()));
        }

        if (primitive.isNumber()) {
            return new Constant(primitive.getAsDouble());
        }

        if (primitive.isString()) {
            String value = primitive.getAsString();

            if (VALID_DOUBLE.matcher(value).matches()) {
                return new Constant(Double.parseDouble(value));
            }

            return compileMolang(value);
        }

        // Fallback for any other primitive kind
        return new Constant(0.0);
    }

    /**
     * Compile a full Molang expression, which may contain multiple {@code ;}-separated statements.
     * <p>
     * If the (stripped) expression starts with {@code return }, only the text up to the first {@code ;} is compiled.
     * Otherwise, if it contains a {@code ;}, each non-blank statement is compiled in order into a
     * {@link CompoundValue}. Compilation stops after the first statement that starts with {@code return }.
     *
     * @param expression The Molang expression
     * @return The compiled value
     * @throws CompoundException If any statement fails to compile, or if a multi-statement expression has no statements
     */
    public static MathValue compileMolang(String expression) {
        String trimmed = expression.strip();

        if (trimmed.startsWith(MOLANG_RETURN)) {
            String returned = trimmed.substring(MOLANG_RETURN.length());
            int delimiterIndex = returned.indexOf(STATEMENT_DELIMITER);

            return compileExpression(delimiterIndex >= 0 ? returned.substring(0, delimiterIndex) : returned);
        }

        if (!trimmed.contains(STATEMENT_DELIMITER)) {
            return compileExpression(trimmed);
        }

        List<MathValue> statements = new ObjectArrayList<>();

        for (String rawStatement : trimmed.split(STATEMENT_DELIMITER)) {
            // Strip so that "a; return b" recognises the return keyword after the delimiter's trailing space
            String statement = rawStatement.strip();

            // Skip empty statements, e.g. from "a;; b" or a trailing "; "
            if (statement.isEmpty()) {
                continue;
            }

            boolean isReturn = statement.startsWith(MOLANG_RETURN);

            statements.add(compileExpression(isReturn ? statement.substring(MOLANG_RETURN.length()) : statement));

            if (isReturn) {
                break;
            }
        }

        if (statements.isEmpty()) {
            throw new CompoundException("No statements found in Molang expression '" + expression + "'");
        }

        return new CompoundValue(statements.toArray(new MathValue[0]));
    }

    /**
     * Compile a single expression (no statement delimiters) into a {@link MathValue}.
     *
     * @throws CompoundException If the expression fails to compile; the exception is annotated with the expression text
     */
    public static MathValue compileExpression(String expression) {
        try {
            return parseSymbols(compileSymbols(decomposeExpression(expression)));
        }
        catch (CompoundException ex) {
            throw ex.withMessage("Failed to parse expression '" + expression + "'");
        }
    }

    /**
     * Validate an expression and convert it into characters ready for {@link #compileSymbols(char[])}.
     *
     * @param expression The raw expression
     * @return The expression lowercased with {@link Locale#ROOT} and with all whitespace removed
     * @throws CompoundException If the expression contains disallowed characters or unbalanced parentheses
     */
    public static char[] decomposeExpression(String expression) throws CompoundException {
        if (!EXPRESSION_FORMAT.matcher(expression).matches()) {
            throw new CompoundException("Invalid characters found in expression: '" + expression + "'");
        }

        char[] chars = WHITESPACE.matcher(expression).replaceAll("").toLowerCase(Locale.ROOT).toCharArray();
        int depth = 0;

        for (char ch : chars) {
            if (ch == '(') {
                depth++;
            }
            else if (ch == ')' && --depth < 0) {
                throw new CompoundException("Closing parenthesis before opening parenthesis in expression '" + expression + "'");
            }
        }

        if (depth != 0) {
            throw new CompoundException("Uneven parenthesis in expression, each opening brace must have a pairing close brace '" + expression + "'");
        }

        return chars;
    }

    /**
     * Try to read an operator or separator symbol starting at the given index.
     * <p>
     * Longer operators are tried first, so multi-character operators win over their single-character prefixes.
     *
     * @param chars The expression characters
     * @param index The index to read from
     * @return The operator string, {@code "?"}, {@code ":"} or {@code ","}, or null if no symbol starts at {@code index}
     */
    protected static @Nullable String tryMergeOperativeSymbols(char[] chars, int index) {
        char ch = chars[index];

        if (!Operator.isOperativeSymbol(ch)) {
            return null;
        }

        for (int length = Math.min(chars.length - index, Operator.maxOperatorLength()); length > 0; length--) {
            String candidate = String.copyValueOf(chars, index, length);

            if (Operator.isOperator(candidate)) {
                return candidate;
            }
        }

        if (ch == '?' || ch == ':' || ch == ',') {
            return String.valueOf(ch);
        }

        return null;
    }

    /**
     * Split expression characters into symbols.
     * <p>
     * Each symbol is one of two kinds. A left value is a value or operator string. A right value is the list of
     * compiled, comma-separated arguments of a parenthesised group.
     *
     * @param chars The characters, as produced by {@link #decomposeExpression(String)}
     * @return The ordered list of symbols
     * @throws CompoundException If a group is not closed, or if any group argument fails to compile
     */
    public static List<Either<String, List<MathValue>>> compileSymbols(char[] chars) {
        List<Either<String, List<MathValue>>> symbols = new ObjectArrayList<>();
        StringBuilder buffer = new StringBuilder();
        int lastOperatorIndex = -1;

        for (int i = 0; i < chars.length; i++) {
            char ch = chars[i];

            // A '-' at the start, or directly after an operator/separator, is a unary minus: keep it with the next value
            if (ch == '-' && buffer.isEmpty() && (symbols.isEmpty() || lastOperatorIndex == symbols.size() - 1)) {
                buffer.append(ch);
                continue;
            }

            String operator = tryMergeOperativeSymbols(chars, i);

            if (operator != null) {
                i += operator.length() - 1;

                if (!buffer.isEmpty()) {
                    symbols.add(Either.left(buffer.toString()));
                }

                lastOperatorIndex = symbols.size();
                symbols.add(Either.left(operator));
                buffer.setLength(0);
                continue;
            }

            if (ch == '(') {
                // Whatever precedes the group (e.g. a function name) becomes its own symbol
                if (!buffer.isEmpty()) {
                    symbols.add(Either.left(buffer.toString()));
                    buffer.setLength(0);
                }

                i = compileGroup(chars, i, symbols);
                continue;
            }

            buffer.append(ch);
        }

        if (!buffer.isEmpty()) {
            symbols.add(Either.left(buffer.toString()));
        }

        return symbols;
    }

    /**
     * Compile the parenthesised group that opens at {@code openIndex} and append it to {@code symbols}.
     * <p>
     * The group becomes a single right-hand symbol: the list of its compiled comma-separated arguments.
     *
     * @return The index of the group's closing parenthesis
     * @throws CompoundException If the group is never closed, or if an argument fails to compile
     */
    private static int compileGroup(char[] chars, int openIndex, List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        List<MathValue> arguments = new ObjectArrayList<>();
        StringBuilder argument = new StringBuilder();
        int depth = 1;

        for (int j = openIndex + 1; j < chars.length; j++) {
            char ch = chars[j];

            if (ch == '(') {
                depth++;
            }
            else if (ch == ')') {
                depth--;
            }
            else if (ch == ',' && depth == 1) {
                // Top-level comma inside this group: finish the current argument
                arguments.add(parseSymbols(compileSymbols(argument.toString().toCharArray())));
                argument.setLength(0);
                continue;
            }

            if (depth == 0) {
                if (!argument.isEmpty()) {
                    arguments.add(parseSymbols(compileSymbols(argument.toString().toCharArray())));
                }

                symbols.add(Either.right(arguments));

                return j;
            }

            argument.append(ch);
        }

        throw new CompoundException("Unclosed parenthesis in expression '" + String.valueOf(chars) + "'");
    }

    /**
     * Build a {@link MathValue} from compiled symbols.
     * <p>
     * A two-symbol sequence is compiled as a function call when its first symbol is a registered function name or
     * starts with {@code -} or {@code !}, and its second symbol is an argument group. Every other sequence goes
     * through {@link #compileValue(List)}.
     *
     * @throws CompoundException If the symbols cannot be compiled
     */
    public static MathValue parseSymbols(List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        if (symbols.size() == 2) {
            Optional<String> prefix = symbols.getFirst().left().filter(left -> left.startsWith("-") || left.startsWith("!") || isFunctionRegistered(left));
            Optional<List<MathValue>> group = symbols.get(1).right();

            if (prefix.isPresent() && group.isPresent()) {
                return compileFunction(prefix.get(), group.get()).orElseThrow(() -> new CompoundException("Unable to parse function '" + prefix.get() + "' with arguments: " + group.get()));
            }
        }

        return compileValue(symbols).orElseThrow(() -> new CompoundException("Unable to parse compiled symbols from expression: " + symbols));
    }

    /**
     * Compile symbols into a single value.
     * <p>
     * The forms are tried from loosest to tightest binding: variable assignment, then ternary, then binary calculation.
     *
     * @return The compiled value, or {@link Optional#empty()} if the symbols match none of the supported forms
     */
    protected static Optional<? extends MathValue> compileValue(List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        if (symbols.size() == 1) {
            return compileSingleValue(symbols.getFirst());
        }

        Optional<VariableAssignment> assignment = compileAssignment(symbols);

        if (assignment.isPresent()) {
            return assignment;
        }

        Optional<Ternary> ternary = compileTernary(symbols);

        if (ternary.isPresent()) {
            return ternary;
        }

        return compileCalculation(symbols);
    }

    /**
     * Compile a variable assignment whose {@code =} appears before any ternary {@code ?}.
     * <p>
     * This makes {@code v.x = a ? b : c} assign the result of the whole ternary, instead of being read as a ternary
     * whose condition is {@code v.x = a}. Assignments that appear inside ternary branches are left to
     * {@link #compileTernary(List)}.
     *
     * @return The assignment, or {@link Optional#empty()} if no such {@code =} exists
     * @throws CompoundException If the left-hand side of the {@code =} is not a variable
     */
    private static Optional<VariableAssignment> compileAssignment(List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        int symbolCount = symbols.size();

        for (int i = 0; i < symbolCount; i++) {
            String symbol = symbols.get(i).left().orElse(null);

            if ("?".equals(symbol)) {
                return Optional.empty();
            }

            if (i > 0 && symbol != null && Operator.isOperator(symbol) && getOperatorFor(symbol) == Operator.ASSIGN_VARIABLE) {
                if (!(parseSymbols(symbols.subList(0, i)) instanceof Variable variable)) {
                    throw new CompoundException("Attempted to assign a value to a non-variable");
                }

                return Optional.of(new VariableAssignment(variable, parseSymbols(symbols.subList(i + 1, symbolCount))));
            }
        }

        return Optional.empty();
    }

    /**
     * Compile a single symbol.
     * <p>
     * An argument group becomes a {@link Group} of its first argument. A string is compiled as a negation, number,
     * variable or no-argument function.
     *
     * @return The compiled value, or {@link Optional#empty()} if the symbol cannot be interpreted (including an empty group)
     */
    protected static Optional<MathValue> compileSingleValue(Either<String, List<MathValue>> symbol) throws CompoundException {
        Optional<List<MathValue>> group = symbol.right();

        if (group.isPresent()) {
            List<MathValue> values = group.get();

            // An empty group such as "()" has no value to wrap
            return values.isEmpty() ? Optional.empty() : Optional.of(new Group(values.getFirst()));
        }

        return symbol.left().map(MathParser::compileStringValue);
    }

    /**
     * Compile a single string symbol.
     * <p>
     * A leading {@code !} or {@code -} is handled recursively, so that both {@code -v.x} and the negated no-argument
     * function {@code -math.pi} are compiled correctly.
     *
     * @return The compiled value, or null if the string cannot be interpreted
     */
    private static @Nullable MathValue compileStringValue(String string) throws CompoundException {
        if (string.startsWith("!")) {
            MathValue operand = compileStringValue(string.substring(1));

            return operand == null ? null : new BooleanNegate(operand);
        }

        if (isNumeric(string)) {
            return new Constant(Double.parseDouble(string));
        }

        // Unary minus on a non-numeric value. Previously "-math.pi" was mistaken for a variable named "math.pi"
        if (string.length() > 1 && string.startsWith("-")) {
            MathValue operand = compileStringValue(string.substring(1));

            return operand == null ? null : new Negative(operand);
        }

        if (isLikelyVariable(string)) {
            return getVariableFor(string);
        }

        if (isFunctionRegistered(string)) {
            return compileFunction(string, List.of()).orElse(null);
        }

        return null;
    }

    /**
     * Compile a binary calculation, or a variable assignment, from symbols.
     * <p>
     * The root of the calculation is the loosest-binding operator. When several operators are equally loose, the
     * rightmost one is chosen, which keeps evaluation left-to-right: {@code 1 - 2 - 3} is {@code (1 - 2) - 3}.
     *
     * @return The calculation, or {@link Optional#empty()} if the symbols contain no operator after the first position
     * @throws CompoundException If an assignment targets a non-variable, or if either operand fails to compile
     */
    protected static Optional<MathValue> compileCalculation(List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        int symbolCount = symbols.size();
        int operatorIndex = -1;
        Operator rootOperator = null;

        for (int i = 1; i < symbolCount; i++) {
            Operator operator = symbols.get(i).left().filter(Operator::isOperator).map(MathParser::getOperatorFor).orElse(null);

            if (operator == null) {
                continue;
            }

            if (operator == Operator.ASSIGN_VARIABLE) {
                if (!(parseSymbols(symbols.subList(0, i)) instanceof Variable variable)) {
                    throw new CompoundException("Attempted to assign a value to a non-variable");
                }

                return Optional.of(new VariableAssignment(variable, parseSymbols(symbols.subList(i + 1, symbolCount))));
            }

            // Skip operators that take precedence over the current root, rather than ending the scan there.
            // Ending the scan would split "1 - 2 * 3 - 4" at the first '-' (given '*' takes precedence over '-'),
            // which evaluates as "1 - (2 * 3 - 4)".
            if (rootOperator == null || !operator.takesPrecedenceOver(rootOperator)) {
                operatorIndex = i;
                rootOperator = operator;
            }
        }

        if (rootOperator == null) {
            return Optional.empty();
        }

        return Optional.of(new Calculation(rootOperator, parseSymbols(symbols.subList(0, operatorIndex)), parseSymbols(symbols.subList(operatorIndex + 1, symbolCount))));
    }

    /**
     * Compile a ternary ({@code condition ? ifTrue : ifFalse}) from symbols.
     * <p>
     * The condition runs up to the first {@code ?}. The true branch runs up to the {@code :} that closes it. The false
     * branch is everything after that {@code :}, so chained ternaries such as {@code a ? b : c ? d : e} compile as
     * {@code a ? b : (c ? d : e)}.
     *
     * @return The ternary, or {@link Optional#empty()} if the symbols do not form a balanced ternary
     */
    protected static Optional<Ternary> compileTernary(List<Either<String, List<MathValue>>> symbols) throws CompoundException {
        int symbolCount = symbols.size();

        if (symbolCount < 3) {
            return Optional.empty();
        }

        int queryIndex = -1;    // Index of the first '?'
        int trueBranchEnd = -1; // Index of the ':' closing the outermost true branch
        int lastColon = -1;
        int depth = 0;          // Number of '?' not yet matched by a ':'

        for (int i = 0; i < symbolCount; i++) {
            String symbol = symbols.get(i).left().orElse(null);

            if ("?".equals(symbol)) {
                if (queryIndex == -1) {
                    queryIndex = i;
                }

                depth++;
            }
            else if (":".equals(symbol)) {
                if (depth == 1 && trueBranchEnd == -1 && queryIndex >= 0) {
                    trueBranchEnd = i;
                }

                depth--;
                lastColon = i;
            }
        }

        if (depth != 0 || queryIndex < 0 || trueBranchEnd < 0 || lastColon >= symbolCount - 1) {
            return Optional.empty();
        }

        MathValue condition = parseSymbols(symbols.subList(0, queryIndex));
        MathValue ifTrue = parseSymbols(symbols.subList(queryIndex + 1, trueBranchEnd));
        MathValue ifFalse = parseSymbols(symbols.subList(trueBranchEnd + 1, symbolCount));

        return Optional.of(new Ternary(condition, ifTrue, ifFalse));
    }

    /**
     * Compile a function call.
     * <p>
     * A leading {@code !} or {@code -} wraps the result in {@link BooleanNegate} or {@link Negative}. A bare
     * {@code !} or {@code -} applies to the first argument.
     *
     * @param name The function name, optionally prefixed with {@code !} or {@code -}
     * @param args The compiled argument values
     * @return The compiled function, or {@link Optional#empty()} if the name is not registered or if a bare prefix has no argument
     */
    protected static Optional<? extends MathValue> compileFunction(String name, List<MathValue> args) throws CompoundException {
        if (name.startsWith("!")) {
            if (name.length() == 1) {
                if (args.isEmpty()) {
                    return Optional.empty();
                }

                return Optional.of(new BooleanNegate(args.getFirst()));
            }

            return compileFunction(name.substring(1), args).map(BooleanNegate::new);
        }

        if (name.startsWith("-")) {
            if (name.length() == 1) {
                if (args.isEmpty()) {
                    return Optional.empty();
                }

                return Optional.of(new Negative(args.getFirst()));
            }

            return compileFunction(name.substring(1), args).map(Negative::new);
        }

        if (!isFunctionRegistered(name)) {
            return Optional.empty();
        }

        return buildFunction(name, args.toArray(new MathValue[0]));
    }

    /**
     * @return Whether the string is a plain decimal number (optional leading {@code -}, optional fractional part)
     */
    public static boolean isNumeric(String string) {
        return NUMERIC.matcher(string).matches();
    }

    /**
     * @return The {@link Operator} for the given symbol
     * @throws CompoundException If the symbol is not a known operator
     */
    protected static Operator getOperatorFor(String op) throws CompoundException {
        return Operator.getOperatorFor(op).orElseThrow(() -> new CompoundException("Unknown operator symbol '" + op + "'"));
    }

    /**
     * Whether the string should be treated as a variable name.
     * <p>
     * Existing variables always qualify. Any other string qualifies unless it is numeric, a registered function name,
     * an operator, or a ternary {@code ?}/{@code :} symbol.
     */
    protected static boolean isLikelyVariable(String string) {
        if (MolangQueries.isExistingVariable(string)) {
            return true;
        }

        return !isNumeric(string) && !isFunctionRegistered(string) && !Operator.isOperator(string) && !string.equals("?") && !string.equals(":");
    }
}

